{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42103cae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://huggingface.co/datasets/mesolitica/semisupervised-corpus/resolve/main/similarity/shuffled-train.json\n",
    "# !wget https://huggingface.co/datasets/mesolitica/semisupervised-corpus/resolve/main/similarity/shuffled-test.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f21abd15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bidirectional_mistral import MistralBiModel\n",
    "from transformers import MistralPreTrainedModel\n",
    "import torch\n",
    "import numpy as np\n",
    "import json\n",
    "from typing import Optional, List\n",
    "from torch import nn\n",
    "from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n",
    "from transformers.modeling_outputs import SequenceClassifierOutputWithPast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bd862a54",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MistralForSequenceClassification(MistralPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = config.num_labels\n",
    "        self.model = MistralBiModel(config)\n",
    "        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n",
    "\n",
    "        # Initialize weights and apply final processing\n",
    "        self.post_init()\n",
    "        \n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids: torch.LongTensor = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.LongTensor] = None,\n",
    "        past_key_values: Optional[List[torch.FloatTensor]] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        labels: Optional[torch.LongTensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n",
    "            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n",
    "            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n",
    "            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        transformer_outputs = self.model(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            position_ids=position_ids,\n",
    "            past_key_values=past_key_values,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "        pooled_output = transformer_outputs[0][:, 0]\n",
    "        logits = self.score(pooled_output)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            if self.config.problem_type is None:\n",
    "                if self.num_labels == 1:\n",
    "                    self.config.problem_type = \"regression\"\n",
    "                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
    "                    self.config.problem_type = \"single_label_classification\"\n",
    "                else:\n",
    "                    self.config.problem_type = \"multi_label_classification\"\n",
    "\n",
    "            if self.config.problem_type == \"regression\":\n",
    "                loss_fct = MSELoss()\n",
    "                if self.num_labels == 1:\n",
    "                    loss = loss_fct(logits.squeeze(), labels.squeeze())\n",
    "                else:\n",
    "                    loss = loss_fct(logits, labels)\n",
    "            elif self.config.problem_type == \"single_label_classification\":\n",
    "                loss_fct = CrossEntropyLoss()\n",
    "                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n",
    "            elif self.config.problem_type == \"multi_label_classification\":\n",
    "                loss_fct = BCEWithLogitsLoss()\n",
    "                loss = loss_fct(logits, labels)\n",
    "        if not return_dict:\n",
    "            output = (logits,) + transformer_outputs[2:]\n",
    "            return ((loss,) + transformer_outputs) if loss is not None else output\n",
    "        \n",
    "        return SequenceClassifierOutputWithPast(\n",
    "            loss=loss,\n",
    "            logits=logits,\n",
    "            past_key_values=transformer_outputs.past_key_values,\n",
    "            hidden_states=transformer_outputs.hidden_states,\n",
    "            attentions=transformer_outputs.attentions,\n",
    "        )"
   ]
  },
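  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a1c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check sketch, not part of training: with num_labels = 2 and integer\n",
    "# labels, the head above routes to CrossEntropyLoss; this reproduces that\n",
    "# branch on dummy tensors using the imports from the cell above.\n",
    "dummy_logits = torch.randn(4, 2)\n",
    "dummy_labels = torch.tensor([0, 1, 1, 0])\n",
    "CrossEntropyLoss()(dummy_logits.view(-1, 2), dummy_labels.view(-1))"
   ]
  },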
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b30b8a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "57905bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('mesolitica/malaysian-mistral-349M-MLM-512')\n",
    "# tokenizer.padding_side = 'left'\n",
    "config = AutoConfig.from_pretrained('mesolitica/malaysian-mistral-349M-MLM-512')\n",
    "config.problem_type = \"single_label_classification\"\n",
    "config.label2id = {'contradiction': 0, 'entailment': 1}"
   ]
  },
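  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d4e7f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick tokenizer check on an arbitrary sample string (Malay for 'example\n",
    "# sentence'), just to confirm the batch keys the model's forward expects.\n",
    "tokenizer(['contoh ayat'], padding = True, return_tensors = 'pt')"
   ]
  },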
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9398590c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation=\"flash_attention_2\"` instead.\n",
      "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour\n",
      "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
      "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in MistralForSequenceClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n",
      "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in MistralBiModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n",
      "Some weights of MistralForSequenceClassification were not initialized from the model checkpoint at mesolitica/malaysian-mistral-349M-MLM-512 and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = MistralForSequenceClassification.from_pretrained(\n",
    "    'mesolitica/malaysian-mistral-349M-MLM-512', \n",
    "    config = config,\n",
    "    use_flash_attention_2 = True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "46ec48b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainable_parameters = [param for param in model.parameters() if param.requires_grad]\n",
    "trainer = torch.optim.AdamW(trainable_parameters, lr = 2e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "db9478b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X, train_Y = [], []\n",
    "with open('shuffled-train.json') as fopen:\n",
    "    for l in fopen:\n",
    "        l = json.loads(l)\n",
    "        train_X.append(l['src'])\n",
    "        train_Y.append(l['label'])"
   ]
  },
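  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b6c9d2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one training pair; labels are the integer indices from\n",
    "# config.label2id (0 = contradiction, 1 = entailment).\n",
    "train_X[0], train_Y[0]"
   ]
  },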
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ce57e047",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16037"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_X, test_Y = [], []\n",
    "with open('shuffled-test.json') as fopen:\n",
    "    for l in fopen:\n",
    "        l = json.loads(l)\n",
    "        test_X.append(l['src'])\n",
    "        test_Y.append(l['label'])\n",
    "        \n",
    "len(test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "67a4fcba",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "epoch = 100"
   ]
  },
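  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0f3a4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each epoch runs ceil(len(train_X) / batch_size) optimizer steps,\n",
    "# which is the total shown by tqdm in the training loop below.\n",
    "(len(train_X) + batch_size - 1) // batch_size"
   ]
  },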
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5eab2ca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "accelerator = Accelerator(mixed_precision = 'bf16')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3a1b3101",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = accelerator.device\n",
    "model, trainer = accelerator.prepare(model, trainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e5bf4f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|█▎                                                           | 5665/273391 [07:47<5:38:35, 13.18it/s, loss=0.688]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "  9%|█████▋                                                      | 25886/273391 [35:34<5:46:56, 11.89it/s, loss=0.467]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 17%|█████████▊                                                | 46298/273391 [1:03:37<5:18:08, 11.90it/s, loss=0.323]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 31%|█████████████████▊                                        | 84250/273391 [1:56:18<4:02:19, 13.01it/s, loss=0.891]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 34%|███████████████████▏                                     | 91768/273391 [2:06:47<3:57:34, 12.74it/s, loss=0.0779]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 41%|███████████████████████▏                                 | 110932/273391 [2:33:31<3:46:12, 11.97it/s, loss=0.481]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 48%|███████████████████████████▋                              | 130325/273391 [3:00:41<3:33:38, 11.16it/s, loss=1.84]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 65%|████████████████████████████████████▌                   | 178584/273391 [4:08:16<2:26:51, 10.76it/s, loss=0.0703]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 73%|██████████████████████████████████████████▍               | 200312/273391 [4:38:42<1:36:37, 12.61it/s, loss=1.25]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 79%|████████████████████████████████████████████▉            | 215532/273391 [5:00:05<1:13:16, 13.16it/s, loss=0.141]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 95%|███████████████████████████████████████████████████████▉   | 259226/273391 [6:02:06<19:21, 12.20it/s, loss=0.427]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "  2%|█▎                                                           | 5685/273391 [08:05<6:14:45, 11.91it/s, loss=0.918]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "  9%|█████▌                                                      | 25555/273391 [36:21<5:35:48, 12.30it/s, loss=0.927]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      " 11%|██████▍                                                     | 29273/273391 [41:39<6:59:49,  9.69it/s, loss=0.619]"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "best_dev_acc = -np.inf\n",
    "patient = 1\n",
    "current_patient = 0\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm(range(0, len(train_X), batch_size))\n",
    "    losses = []\n",
    "    for i in pbar:\n",
    "        trainer.zero_grad()\n",
    "        x = train_X[i: i + batch_size]\n",
    "        y = np.array(train_Y[i: i + batch_size])\n",
    "        \n",
    "        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 2048)\n",
    "        padded['labels'] = torch.from_numpy(y)\n",
    "        for k in padded.keys():\n",
    "            padded[k] = padded[k].cuda()\n",
    "\n",
    "        padded.pop('token_type_ids', None)\n",
    "\n",
    "        o = model(**padded, use_cache = False)\n",
    "        loss = o.loss\n",
    "        \n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        grad_norm = torch.nn.utils.clip_grad_norm_(trainable_parameters, 1.0)\n",
    "        trainer.step()\n",
    "            \n",
    "        losses.append(float(loss))\n",
    "        pbar.set_postfix(loss = float(loss))\n",
    "        \n",
    "        \n",
    "    dev_predicted = []\n",
    "    for i in range(0, len(test_X[:10000]), batch_size):\n",
    "        x = test_X[i: i + batch_size]\n",
    "        y = np.array(test_Y[i: i + batch_size])\n",
    "        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 2048)\n",
    "        padded['labels'] = torch.from_numpy(y)\n",
    "        for k in padded.keys():\n",
    "            padded[k] = padded[k].cuda()\n",
    "            \n",
    "        padded.pop('token_type_ids', None)\n",
    "        \n",
    "        o = model(**padded, use_cache = False)\n",
    "        loss = o.loss\n",
    "        pred = o.logits\n",
    "        \n",
    "        dev_predicted.append((pred.argmax(axis = 1).detach().cpu().numpy() == y).mean())\n",
    "        \n",
    "    dev_predicted = np.mean(dev_predicted)\n",
    "    \n",
    "    print(f'epoch: {e}, loss: {np.mean(losses)}, dev_predicted: {dev_predicted}')\n",
    "    \n",
    "    if dev_predicted >= best_dev_acc:\n",
    "        best_dev_acc = dev_predicted\n",
    "        current_patient = 0\n",
    "        model.save_pretrained('64M')\n",
    "    else:\n",
    "        current_patient += 1\n",
    "    \n",
    "    if current_patient >= patient:\n",
    "        break"
   ]
  },
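  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a7b8c5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal single-input inference sketch (`predict` is a hypothetical helper,\n",
    "# not part of the original run): tokenize one string the same way as training,\n",
    "# take the argmax logit, and map it back to a name by reversing config.label2id.\n",
    "def predict(text):\n",
    "    id2label = {v: k for k, v in config.label2id.items()}\n",
    "    padded = tokenizer([text], truncation = True, padding = True, return_tensors = 'pt', max_length = 2048)\n",
    "    padded.pop('token_type_ids', None)\n",
    "    padded = {k: v.to(device) for k, v in padded.items()}\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        logits = model(**padded, use_cache = False).logits\n",
    "    return id2label[int(logits.argmax(axis = 1)[0])]\n",
    "\n",
    "# predict(test_X[0])"
   ]
  },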
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "efe9c2e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wed Apr 24 12:08:10 2024       \r\n",
      "+---------------------------------------------------------------------------------------+\r\n",
      "| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |\r\n",
      "|-----------------------------------------+----------------------+----------------------+\r\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\r\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\r\n",
      "|                                         |                      |               MIG M. |\r\n",
      "|=========================================+======================+======================|\r\n",
      "|   0  NVIDIA GeForce RTX 3090 Ti     On  | 00000000:01:00.0 Off |                  Off |\r\n",
      "| 46%   53C    P2             169W / 480W |  21145MiB / 24564MiB |      0%      Default |\r\n",
      "|                                         |                      |                  N/A |\r\n",
      "+-----------------------------------------+----------------------+----------------------+\r\n",
      "|   1  NVIDIA GeForce RTX 3090 Ti     On  | 00000000:08:00.0 Off |                  Off |\r\n",
      "|  0%   49C    P8              28W / 450W |      6MiB / 24564MiB |      0%      Default |\r\n",
      "|                                         |                      |                  N/A |\r\n",
      "+-----------------------------------------+----------------------+----------------------+\r\n",
      "                                                                                         \r\n",
      "+---------------------------------------------------------------------------------------+\r\n",
      "| Processes:                                                                            |\r\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\r\n",
      "|        ID   ID                                                             Usage      |\r\n",
      "|=======================================================================================|\r\n",
      "|    0   N/A  N/A   2697189      C   /usr/bin/python3                          21136MiB |\r\n",
      "+---------------------------------------------------------------------------------------+\r\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "77e9313d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 4010/4010 [01:20<00:00, 49.60it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y = []\n",
    "for i in tqdm(range(0, len(test_X), batch_size)):\n",
    "    x = test_X[i: i + batch_size]\n",
    "    y = np.array(test_Y[i: i + batch_size])\n",
    "    padded = tokenizer(x, padding = 'longest', return_tensors = 'pt')\n",
    "    padded['labels'] = torch.from_numpy(y)\n",
    "    for k in padded.keys():\n",
    "        padded[k] = padded[k].cuda()\n",
    "    \n",
    "    padded.pop('token_type_ids', None)\n",
    "    o = model(**padded, use_cache = False)\n",
    "    loss = o.loss\n",
    "    pred = o.logits\n",
    "    real_Y.extend(pred.argmax(axis = 1).detach().cpu().numpy().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0ddab41a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.71505   0.77944   0.74586      7073\n",
      "           1    0.81266   0.75491   0.78272      8964\n",
      "\n",
      "    accuracy                        0.76573     16037\n",
      "   macro avg    0.76385   0.76718   0.76429     16037\n",
      "weighted avg    0.76961   0.76573   0.76646     16037\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(\n",
    "    metrics.classification_report(\n",
    "        real_Y, test_Y[:len(real_Y)],\n",
    "        digits = 5\n",
    "    )\n",
    ")"
   ]
  },
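  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c1d2e3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional companion view of the same predictions: confusion matrix with\n",
    "# rows = true labels, columns = predicted (0 = contradiction, 1 = entailment).\n",
    "metrics.confusion_matrix(test_Y[:len(real_Y)], real_Y)"
   ]
  },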
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "209dfb59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16037"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(real_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1cad046d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16037"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(test_Y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
